{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tool calling with mistral.rs\n",
    "\n",
    "Tool calling is a technique to enhance generation by providing the model with functions (tools) which it may call."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the tools and their schemas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import StringIO\n",
    "import json\n",
    "import sys\n",
    "from mistralrs import Runner, ToolChoice, Which, ChatCompletionRequest, Architecture\n",
    "\n",
    "# Each schema is serialized with `json.dumps`, so `tools` is a list of JSON strings.\n",
    "tools = [\n",
    "    json.dumps(\n",
    "        {\n",
    "            \"type\": \"function\",\n",
    "            \"function\": {\n",
    "                \"name\": \"run_python\",\n",
    "                \"description\": \"Run some Python code\",\n",
    "                \"parameters\": {\n",
    "                    # A parameter set with named `properties` is a JSON Schema\n",
    "                    # object; only the individual `code` property is a string.\n",
    "                    \"type\": \"object\",\n",
    "                    \"properties\": {\n",
    "                        \"code\": {\n",
    "                            \"type\": \"string\",\n",
    "                            \"description\": \"The Python code to evaluate. The return value whatever was printed out from `print`.\",\n",
    "                        },\n",
    "                    },\n",
    "                    \"required\": [\"code\"],\n",
    "                },\n",
    "            },\n",
    "        }\n",
    "    )\n",
    "]\n",
    "\n",
    "\n",
    "def custom_serializer(obj):\n",
    "    \"\"\"Serialize `obj` to a JSON string, or return `None` when that fails.\n",
    "\n",
    "    Args:\n",
    "        obj: Any value. Values that `json.dumps` rejects (for example an\n",
    "            imported module) are tolerated.\n",
    "\n",
    "    Returns:\n",
    "        The JSON text for `obj`, or `None` if `json.dumps` raised.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return json.dumps(obj)\n",
    "    except Exception:\n",
    "        # Best-effort: anything that cannot be serialized maps to None.\n",
    "        # Catching `Exception` instead of using a bare `except` lets\n",
    "        # KeyboardInterrupt and SystemExit propagate, so the kernel can\n",
    "        # still be interrupted.\n",
    "        return None\n",
    "\n",
    "\n",
    "def run_python(code: str) -> str:\n",
    "    \"\"\"Execute `code` with `exec` and return everything it printed.\n",
    "\n",
    "    WARNING: `code` is produced by the model and runs inside this process.\n",
    "    Shadowing `open` with `None` only blocks direct `open(...)` calls; it is\n",
    "    not a sandbox. Only use this with models and prompts you trust.\n",
    "\n",
    "    Args:\n",
    "        code: Python source to run.\n",
    "\n",
    "    Returns:\n",
    "        The text written to `sys.stdout` while `code` ran.\n",
    "\n",
    "    Raises:\n",
    "        Exception: Whatever `code` raises propagates to the caller.\n",
    "    \"\"\"\n",
    "    # One dict serves as both globals and locals, like a module namespace, so\n",
    "    # functions defined by `code` can see its own top-level imports and\n",
    "    # variables. `open` is shadowed to discourage file access.\n",
    "    namespace = {\"open\": None}\n",
    "\n",
    "    print(f\"Running:\\n```py\\n{code}\\n```\")\n",
    "\n",
    "    old_stdout = sys.stdout\n",
    "    out = StringIO()\n",
    "    sys.stdout = out\n",
    "    try:\n",
    "        exec(code, namespace)\n",
    "    finally:\n",
    "        # Always restore the real stdout, even when `code` raises; otherwise\n",
    "        # every later print in the notebook would go into the buffer.\n",
    "        sys.stdout = old_stdout\n",
    "\n",
    "    return out.getvalue()\n",
    "\n",
    "\n",
    "# Lookup table from tool name to the Python callable that implements it.\n",
    "functions = dict(run_python=run_python)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the Runner and the initial messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The conversation starts with a single user question.\n",
    "messages = [\n",
    "    dict(\n",
    "        role=\"user\",\n",
    "        content=\"What is the value of the area of a circle with radius 4?\",\n",
    "    )\n",
    "]\n",
    "\n",
    "# Build the mistral.rs runner for a plain (`Which.Plain`) Llama-architecture model.\n",
    "runner = Runner(\n",
    "    which=Which.Plain(\n",
    "        arch=Architecture.Llama,\n",
    "        model_id=\"lamm-mit/Bioinspired-Llama-3-1-8B-128k-gamma\",\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ask the model and give it access to the tools\n",
    "\n",
    "If the model decides to call a tool, its response contains the chosen tool call(s). Because this is a demo, we only extract the first one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First round trip: send the conversation together with the tool schemas and\n",
    "# let the model decide (`ToolChoice.Auto`) whether to call a tool.\n",
    "res = runner.send_chat_completion_request(\n",
    "    ChatCompletionRequest(\n",
    "        model=\"llama-3.1\",\n",
    "        messages=messages,\n",
    "        max_tokens=256,\n",
    "        presence_penalty=1.0,\n",
    "        top_p=0.1,\n",
    "        temperature=0.1,\n",
    "        tool_schemas=tools,\n",
    "        tool_choice=ToolChoice.Auto,\n",
    "    )\n",
    ")\n",
    "\n",
    "# The model may answer directly instead of calling a tool. Check before\n",
    "# indexing so the failure is descriptive rather than an opaque\n",
    "# IndexError/TypeError.\n",
    "first_message = res.choices[0].message\n",
    "if not first_message.tool_calls:\n",
    "    raise RuntimeError(\n",
    "        f\"The model did not request a tool call. Its reply was: {first_message.content!r}\"\n",
    "    )\n",
    "\n",
    "# This demo only handles the first requested tool call.\n",
    "tool_called = first_message.tool_calls[0].function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the tool, then give the model its own tool call and the tool's output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if tool_called.name in functions:\n",
    "    # The arguments arrive as a JSON string; decode them into keyword arguments.\n",
    "    args = json.loads(tool_called.arguments)\n",
    "    result = functions[tool_called.name](**args)\n",
    "    print(f\"Called tool `{tool_called.name}`\")\n",
    "\n",
    "    # Replay the model's own tool call as an assistant turn ...\n",
    "    messages.append(\n",
    "        {\n",
    "            \"role\": \"assistant\",\n",
    "            \"content\": json.dumps({\"name\": tool_called.name, \"parameters\": args}),\n",
    "        }\n",
    "    )\n",
    "\n",
    "    # ... followed by the tool's output, so the model can use it in its answer.\n",
    "    messages.append({\"role\": \"tool\", \"content\": result})\n",
    "\n",
    "    # Second round trip: same settings, now with the tool result in the history.\n",
    "    res = runner.send_chat_completion_request(\n",
    "        ChatCompletionRequest(\n",
    "            model=\"llama-3.1\",\n",
    "            messages=messages,\n",
    "            max_tokens=256,\n",
    "            presence_penalty=1.0,\n",
    "            top_p=0.1,\n",
    "            temperature=0.1,\n",
    "            tool_schemas=tools,\n",
    "            tool_choice=ToolChoice.Auto,\n",
    "        )\n",
    "    )\n",
    "    print(res.choices[0].message.content)\n",
    "else:\n",
    "    # Do not fail silently when the model asks for a tool that is not registered.\n",
    "    print(f\"Model requested unknown tool `{tool_called.name}`; registered tools: {sorted(functions)}\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
